""" Dictionary learning.
"""
# Author: Vlad Niculae, Gael Varoquaux, Alexandre Gramfort
# License: BSD 3 clause
# Modified from sklearn.decomposition (matrix decomposition module)
# https://scikit-learn.org/stable/modules/classes.html#module-sklearn.decomposition

import time
import sys
import itertools

from math import ceil

import numpy as np
from scipy import linalg
from joblib import Parallel, effective_n_jobs

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import deprecated
from sklearn.utils import (check_array, check_random_state, gen_even_slices,
                     gen_batches)
from sklearn.utils.extmath import randomized_svd, row_norms
from sklearn.utils.validation import check_is_fitted, _deprecate_positional_args
# from sklearn.utils.fixes import delayed
from sklearn.linear_model import Lasso, orthogonal_mp_gram, LassoLars, Lars

# custom define K-SVD Class Requirement
from sklearn import linear_model
from collections import defaultdict
import inspect
from sklearn.decomposition import sparse_encode




import os
import numpy as np
from sklearn import linear_model
import scipy.misc
from matplotlib import pyplot as plt
import scipy.io as scio





class KSVD(object):
    """K-SVD dictionary learning.

    Sparse model Y = D X, where Y is the sample matrix. K-SVD alternates a
    sparse coding step (solve for X given the dictionary D) with an SVD-based
    dictionary update (one atom at a time) until the reconstruction error
    ||Y - DX|| falls below ``tol`` or ``max_iter`` iterations have run.
    """

    def __init__(self, n_components, max_iter=30, tol=1e-6,
                 n_nonzero_coefs=None, transform_algorithm='omp',
                 transform_max_iter=1000) -> None:
        """
        :param n_components: number of dictionary atoms (columns of D)
        :param max_iter: maximum number of K-SVD iterations
        :param tol: tolerance on the reconstruction error ||Y - DX||
        :param n_nonzero_coefs: target sparsity of the code
        :param transform_algorithm: sparse coding method; one of
            'omp', 'lars', 'lasso', 'GramOmp', 'lasso_cd'
        :param transform_max_iter: iteration budget for the coding algorithm
            (NOTE(review): stored but currently unused by fit() — confirm)
        """
        super().__init__()
        self.dictionary = None   # learned dictionary D, set by fit()
        self.sparsecode = None   # final sparse code X, set by fit()
        self.max_iter = max_iter
        self.tol = tol
        self.n_components = n_components
        self.n_nonzero_coefs = n_nonzero_coefs
        self.transform_algorithm = transform_algorithm
        self.transform_max_iter = transform_max_iter
        self.loss = 0            # per-iteration error history, set by fit()

    @classmethod
    def _get_param_names(cls):
        """Return the sorted names of the constructor parameters."""
        # Fetch the constructor, or the original constructor before
        # deprecation wrapping, if any.
        init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
        if init is object.__init__:
            # No explicit constructor to introspect.
            return []

        # Introspect the constructor arguments to find the model parameters
        # to represent; exclude 'self' and **kwargs.
        init_signature = inspect.signature(init)
        parameters = [p for p in init_signature.parameters.values()
                      if p.name != 'self' and p.kind != p.VAR_KEYWORD]
        for p in parameters:
            if p.kind == p.VAR_POSITIONAL:
                raise RuntimeError("scikit-learn estimators should always "
                                   "specify their parameters in the signature"
                                   " of their __init__ (no varargs)."
                                   " %s with constructor %s doesn't "
                                   " follow this convention."
                                   % (cls, init_signature))
        return sorted([p.name for p in parameters])

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : bool, default=True
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        out = dict()
        for key in self._get_param_names():
            try:
                value = getattr(self, key)
            except AttributeError:
                # Fix: 'warnings' was used without being imported, which
                # raised NameError instead of warning on this path.
                import warnings
                warnings.warn('check the parameters',
                              FutureWarning)
                value = None
            if deep and hasattr(value, 'get_params'):
                # Flatten nested estimator parameters as 'key__subkey'.
                deep_items = value.get_params().items()
                out.update((key + '__' + k, val) for k, val in deep_items)
            out[key] = value
        return out

    def set_params(self, **params):
        """Set the parameters of this estimator.

        Supports nested parameters via the 'component__parameter' syntax.

        Parameters
        ----------
        **params : dict
            Estimator parameters.

        Returns
        -------
        self : object
            Estimator instance.

        Raises
        ------
        ValueError
            If a key does not name a valid constructor parameter.
        """
        if not params:
            # Simple optimization to gain speed (inspect is slow).
            return self
        valid_params = self.get_params(deep=True)

        nested_params = defaultdict(dict)  # grouped by prefix
        for key, value in params.items():
            key, delim, sub_key = key.partition('__')
            if key not in valid_params:
                raise ValueError('Invalid parameter %s for estimator %s. '
                                 'Check the list of available parameters '
                                 'with `estimator.get_params().keys()`.' %
                                 (key, self))
            if delim:
                nested_params[key][sub_key] = value
            else:
                setattr(self, key, value)
                valid_params[key] = value

        # Recurse into nested estimators for grouped sub-parameters.
        for key, sub_params in nested_params.items():
            valid_params[key].set_params(**sub_params)

        return self

    def _initialize(self, y):
        """Initialize the dictionary from the left singular vectors of y."""
        u, s, v = linalg.svd(y)
        # First n_components left singular vectors become the initial atoms.
        self.dictionary = u[:, :self.n_components]

    def _update_dict(self, y, d, x):
        """K-SVD dictionary update step.

        Updates ``d`` and ``x`` in place, one atom at a time, via a rank-1
        SVD of the residual restricted to the samples that use the atom.

        :param y: sample matrix
        :param d: current dictionary (modified in place)
        :param x: current sparse code (modified in place)
        :return: the updated (d, x) pair (same objects, for convenience)
        """
        for i in range(self.n_components):
            # Samples whose code actually uses atom i.
            index = np.nonzero(x[i, :])[0]
            if len(index) == 0:
                continue
            # Zero the atom so the residual excludes its contribution.
            d[:, i] = 0
            # Residual restricted to the samples that use the atom.
            r = (y - np.dot(d, x))[:, index]
            # Rank-1 approximation of the residual via SVD.
            u, s, v = np.linalg.svd(r, full_matrices=False)
            # First left singular vector becomes the new atom.
            d[:, i] = u[:, 0]
            # First singular value times first right singular vector updates
            # the coefficients of the affected samples.
            for j, k in enumerate(index):
                x[i, k] = s[0] * v[0, j]
        return d, x

    def fit(self, y, para_alpha=1):
        """Run the K-SVD iterations on sample matrix ``y``.

        :param y: samples, one signal per column
        :param para_alpha: regularization weight.
            NOTE(review): only kept for interface compatibility — the
            'lasso' branch hard-codes alpha=0.01 and ignores it; confirm.
        :return: self (with ``dictionary``, ``sparsecode``, ``components_``
            and ``loss`` populated)
        :raises ValueError: if ``transform_algorithm`` is not recognized
        """
        self._initialize(y)
        x = None
        e_loss = []
        for i in range(self.max_iter):
            # --- sparse coding step ---
            if self.transform_algorithm == "omp":
                x = linear_model.orthogonal_mp(
                    self.dictionary, y, n_nonzero_coefs=self.n_nonzero_coefs)
            elif self.transform_algorithm == "lars":
                _, x, _ = linear_model.lars_path(
                    X=self.dictionary, y=y,
                    Gram=self.dictionary.T.dot(self.dictionary), method='lar')
            elif self.transform_algorithm == "lasso":
                # NOTE(review): alpha is hard-coded; para_alpha is unused here.
                reg = linear_model.LassoLars(alpha=0.01)
                reg.fit(self.dictionary, y)
                x = reg.coef_
            elif self.transform_algorithm == "GramOmp":
                x = linear_model.orthogonal_mp_gram(
                    self.dictionary.T.dot(self.dictionary),
                    self.dictionary.T.dot(y))
            elif self.transform_algorithm == "lasso_cd":
                x = sparse_encode(self.dictionary, y.T,
                                  n_nonzero_coefs=self.n_nonzero_coefs,
                                  max_iter=500)
            else:
                # Fail fast instead of crashing later on np.dot(D, None).
                raise ValueError(
                    "Unknown transform_algorithm: %r" % self.transform_algorithm)
            # Frobenius-norm reconstruction error for this iteration.
            e = np.linalg.norm(y - np.dot(self.dictionary, x))
            e_loss.append(e)
            if e < self.tol:
                break
            # --- dictionary update step (in place) ---
            self._update_dict(y, self.dictionary, x)
        self.loss = np.array(e_loss)
        self.sparsecode = x
        self.components_ = self.dictionary
        return self







